# Required Libraries
import ast

import pandas as pd
import plotly.express as px
from sklearn.decomposition import LatentDirichletAllocation
from sklearn.feature_extraction.text import TfidfVectorizer

# --- Load and combine the Reddit text data ---
# Comments contribute their lowercased body text; submissions contribute
# their lowercased title. Both are unified under one 'processed_text'
# column so they can be modeled together.
comments_data = pd.read_csv('data/state_mentions_comments_nlp.csv', low_memory=False)
submissions_data = pd.read_csv('data/state_mentions_submissions_nlp.csv', low_memory=False)

comments_data['processed_text'] = comments_data['body_lower']
submissions_data['processed_text'] = submissions_data['title_lower']

# Stack both sources, keeping only the text and the state-mention
# annotations needed downstream.
frames = [
    frame[['processed_text', 'states_mentioned']]
    for frame in (comments_data, submissions_data)
]
combined_data = pd.concat(frames, ignore_index=True)

# Topic Modeling: TF-IDF over the combined text, then LDA to assign each
# document its dominant latent topic (4 components, fixed seed).
vectorizer = TfidfVectorizer(stop_words='english', max_features=1000)
lda_model = LatentDirichletAllocation(n_components=4, random_state=42)

# Sample at most 10,000 rows. The original hard-coded n=10000 raised a
# ValueError whenever the combined data had fewer rows than that.
sample_size = min(10000, len(combined_data))
sampled_data = combined_data.sample(n=sample_size, random_state=42).copy()

X_sampled = vectorizer.fit_transform(sampled_data['processed_text'].fillna(''))
lda_model.fit(X_sampled)
topic_results = lda_model.transform(X_sampled)
# Dominant topic index per document.
sampled_data['Topic'] = topic_results.argmax(axis=1)


def _as_state_list(value):
    """Best-effort conversion of a 'states_mentioned' cell to a real list.

    The column is read from CSV, so list-valued cells arrive as strings
    like "['texas', 'ohio']". DataFrame.explode is a no-op on plain
    strings, so without parsing, multi-state rows were never split and
    could never match the state->abbreviation map downstream.
    """
    if isinstance(value, list):
        return value
    if isinstance(value, str):
        try:
            parsed = ast.literal_eval(value)
            if isinstance(parsed, list):
                return parsed
        except (ValueError, SyntaxError):
            pass  # not a stringified list; treat the raw string as one state
        return [value]
    return [value]  # e.g. NaN — explode leaves it as a single row, as before


# Flatten state mentions for aggregation: one row per (document, state).
sampled_data['states_mentioned'] = sampled_data['states_mentioned'].map(_as_state_list)
sampled_data = sampled_data.explode('states_mentioned')

# Aggregate and identify the most common topic for each state:
# count (state, topic) pairs, then keep each state's highest-count row.
state_topic_summary = (
    sampled_data.groupby(['states_mentioned', 'Topic'])
    .size()
    .reset_index(name='Count')
)
# .copy() makes this an independent frame; the original assigned new
# columns into a .loc slice, which triggers pandas'
# SettingWithCopyWarning (here and at the later abbreviation assignment).
most_common_topics = state_topic_summary.loc[
    state_topic_summary.groupby('states_mentioned')['Count'].idxmax()
].copy()

# Map numeric topics to meaningful descriptions.
# NOTE(review): these labels assume a fixed interpretation of the four
# LDA components — confirm against each component's top terms.
topic_descriptions = {
    0: "Tourism and Travel",
    1: "Politics and Governance",
    2: "Lifestyle and Culture",
    3: "Economy and Business",
}
most_common_topics['Topic_Label'] = most_common_topics['Topic'].map(topic_descriptions)

# Lowercase full state name -> USPS two-letter code, used to translate
# state mentions into the abbreviations plotly's 'USA-states' mode expects.
# Built from two parallel sequences kept in the same order.
_state_names = [
    "alabama", "alaska", "arizona", "arkansas", "california", "colorado",
    "connecticut", "delaware", "florida", "georgia", "hawaii", "idaho",
    "illinois", "indiana", "iowa", "kansas", "kentucky", "louisiana",
    "maine", "maryland", "massachusetts", "michigan", "minnesota",
    "mississippi", "missouri", "montana", "nebraska", "nevada",
    "new hampshire", "new jersey", "new mexico", "new york",
    "north carolina", "north dakota", "ohio", "oklahoma", "oregon",
    "pennsylvania", "rhode island", "south carolina", "south dakota",
    "tennessee", "texas", "utah", "vermont", "virginia", "washington",
    "west virginia", "wisconsin", "wyoming",
]
_state_abbrevs = [
    "AL", "AK", "AZ", "AR", "CA", "CO", "CT", "DE", "FL", "GA", "HI",
    "ID", "IL", "IN", "IA", "KS", "KY", "LA", "ME", "MD", "MA", "MI",
    "MN", "MS", "MO", "MT", "NE", "NV", "NH", "NJ", "NM", "NY", "NC",
    "ND", "OH", "OK", "OR", "PA", "RI", "SC", "SD", "TN", "TX", "UT",
    "VT", "VA", "WA", "WV", "WI", "WY",
]
state_to_abbreviation = dict(zip(_state_names, _state_abbrevs))

most_common_topics['State_Abbreviation'] = most_common_topics['states_mentioned'].str.strip("[]").str.replace("'", "").str.lower().map(state_to_abbreviation)

# Interactive choropleth of the USA: each state is coloured by its most
# common topic label and hovering shows the state code and topic.
hover_labels = {
    'Topic_Label': 'Most Common Topic',
    'State_Abbreviation': 'State',  # rename the hover field to "State"
}
fig = px.choropleth(
    most_common_topics,
    locations='State_Abbreviation',
    locationmode='USA-states',  # locations are USPS two-letter codes
    scope='usa',
    color='Topic_Label',
    hover_name='State_Abbreviation',
    hover_data={'Topic_Label': True},
    labels=hover_labels,
    title='Most Common Topic Per State',
)

# White lakes plus a fixed 1000x600 canvas.
fig.update_layout(
    geo=dict(lakecolor='rgb(255, 255, 255)'),
    width=1000,
    height=600,
)

fig.show()
# NOTE(review): assumes ../../website-source/plots/ already exists —
# write_html does not create missing directories.
fig.write_html("../../website-source/plots/topic_per_state.html")